package org.hepeng.workx.spring.session.redis.serializer;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.type.TypeFactory;
import net.bytebuddy.dynamic.loading.ByteArrayClassLoader;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.io.FileUtils;
import org.joor.Reflect;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.support.DeserializingConverter;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.JdkSerializationRedisSerializer;

import java.io.File;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
 * @author he peng
 */
public class SafetyRedisSerializerUtils {

    public static JdkSerializationRedisSerializer safeJdkRedisSerializer(JdkSerializationRedisSerializer serializer) {
        DeserializingConverter converter = Reflect.on(serializer).get("deserializer");
        DefaultDeserializer deserializer = Reflect.on(converter).get("deserializer");
        ByteArrayClassLoader classLoader =
                new ByteArrayClassLoader(ClassLoader.getSystemClassLoader(), getLocalCacheTypeDefinitions());
        Reflect.on(deserializer).set("classLoader" , classLoader);
        return serializer;
    }

    public static GenericJackson2JsonRedisSerializer safeGenericJackson2JsonRedisSerializer(GenericJackson2JsonRedisSerializer serializer) {
        ObjectMapper mapper = Reflect.on(serializer).get("mapper");
        TypeFactory typeFactory = mapper.getTypeFactory();
        ByteArrayClassLoader classLoader =
                new ByteArrayClassLoader(ClassLoader.getSystemClassLoader(), getLocalCacheTypeDefinitions());
        Reflect.on(typeFactory).set("_classLoader" , classLoader);
        mapper.setTypeFactory(typeFactory);
        return serializer;
    }

    public static Map<String, byte[]> getLocalCacheTypeDefinitions() {
        Map<String, byte[]> typeDefinitions = new HashMap<>();
        File classesDirFile = new File(SafetyRedisSerializerConfiguration.CLASSES_DIR);
        if (classesDirFile.exists()) {
            Collection<File> classFiles = FileUtils.listFiles(classesDirFile , null, true);
            if (CollectionUtils.isNotEmpty(classFiles)) {
                File file = null;
                try {
                    for (File classFile : classFiles) {
                        file = classFile;
                        byte[] classBytes = FileUtils.readFileToByteArray(classFile);
                        typeDefinitions.put(classFile.getName() , classBytes);
                    }
                } catch (Throwable t) {
                    String msg = Objects.nonNull(file)
                            ? "read class file [" + file.getPath() + "] failed"
                            : "read class file failed";
                    throw new RuntimeException(msg , t);
                }
            }
        }
        return typeDefinitions;
    }
}
